#' Format MR results for a 1-to-many forest plot
#'
#' This function formats user-supplied results for the [forest_plot_1_to_many()] function.
#' The user supplies their results in the form of a data frame.
#' The data frame is assumed to contain at least three columns of data:
#' \enumerate{
#' \item effect estimates, from an analysis of the effect of an exposure on an outcome;
#' \item standard errors for the effect estimates; and
#' \item a column of trait names, corresponding to the 'many' in a 1-to-many forest plot.
#' }
#'
#' @param mr_res Data frame of results supplied by the user.
#' @param b Name of the column specifying the effect of the exposure on the outcome. Default = `"b"`.
#' @param se Name of the column specifying the standard error for b. Default = `"se"`.
#' @param TraitM The column specifying the names of the traits. Corresponds to 'many' in the 1-to-many forest plot. Default=`"outcome"`.
#' @param addcols Name of any additional columns to add to the plot. Character vector. The default is `NULL`.
#' @param by Name of the column indicating a grouping variable to stratify results on.
#' Its values are returned in the `category` column of the output. Default=`NULL` (no stratification).
#' @param exponentiate Convert log odds ratios to odds ratios? Default=`FALSE`.
#' @param ao_slc Logical; retrieve trait subcategory information using [available_outcomes()]. Default=`FALSE`.
#' @param weight Name of the column in `mr_res` holding a weight for each result; its values are
#' returned in the `weight` column of the output. If `NULL` (the default) every result is given a weight of 3.
#'
#' @export
#' @return Data frame with one row per result and the columns `exposure`, `outcome`, `outcome2`,
#' `category`, `effect`, `se`, `up_ci`, `lo_ci`, `index` and (when available) `weight`,
#' followed by any columns requested through `addcols`. Rows are ordered by `outcome`.
format_1_to_many <- function(
  mr_res,
  b = "b",
  se = "se",
  exponentiate = FALSE,
  ao_slc = FALSE,
  by = NULL,
  TraitM = "outcome",
  addcols = NULL,
  weight = NULL
) {
  # Stratification: the rest of the function reads the grouping from
  # `subcategory`, so copy the user's grouping column into it. Copying (rather
  # than renaming) keeps the original column available for `addcols`.
  if (is.null(by)) {
    mr_res$subcategory <- ""
  } else if (by %in% names(mr_res)) {
    mr_res$subcategory <- mr_res[[by]]
  } else {
    warning("Grouping column '", by, "' not found in mr_res; ignoring the by argument.")
  }

  # Weights: the output always reads the column called `weight`, so copy the
  # user's column into it when it goes by another name.
  if (is.null(weight)) {
    mr_res$weight <- 3
  } else if (is.character(weight) && length(weight) == 1L && weight %in% names(mr_res)) {
    mr_res$weight <- mr_res[[weight]]
  }

  if (TraitM == "exposure") {
    # `exposure` is blanked further down (the plot functions draw one panel per
    # unique exposure - a legacy of the original multiple-exposures forest
    # plot), so the trait column must not keep that name.
    names(mr_res)[names(mr_res) == "exposure"] <- "TraitM"
    TraitM <- "TraitM"
  }

  names(mr_res)[names(mr_res) == b] <- "b"
  names(mr_res)[names(mr_res) == se] <- "se"

  # Two-letter prefixes (AA, AB, ..., DZ) record the input order of the traits
  # in a form that survives alphabetical sorting of the labels.
  prefixes <- sort(paste0(rep(c("A", "B", "C", "D"), each = length(LETTERS)), LETTERS))
  n_res <- nrow(mr_res)
  if (n_res > length(prefixes)) {
    warning(
      "More than ",
      length(prefixes),
      " results supplied; the ordering of results beyond that number is not preserved."
    )
  }
  mr_res$outcome2 <- mr_res[, TraitM]
  mr_res[, TraitM] <- paste(prefixes[seq_len(n_res)], mr_res[, TraitM])

  if (!"subcategory" %in% names(mr_res)) {
    mr_res$subcategory <- ""
  }
  mr_res$subcategory <- trim(mr_res$subcategory)
  mr_res$exposure <- ""

  # Get extra info on outcomes
  if (ao_slc) {
    ao <- available_outcomes()
    ao$subcategory[ao$subcategory == "Cardiovascular"] <- "Cardiometabolic"
    ao$subcategory[ao$trait == "Type 2 diabetes"] <- "Cardiometabolic"
    names(ao)[names(ao) == "nsnp"] <- "nsnp.array"
  }

  dat <- mr_res
  # Remember the input order before any merging / sorting
  dat$index <- seq_len(nrow(dat))

  if (ao_slc) {
    dat <- merge(dat, ao, by.x = "id.outcome", by.y = "id")
  }
  dat <- dat[order(dat$b), ]

  # Create 95% CIs on the scale the estimates were supplied on
  estimate <- as.numeric(dat$b)
  half_width <- 1.96 * as.numeric(dat$se)
  dat$up_ci <- estimate + half_width
  dat$lo_ci <- estimate - half_width

  # Exponentiate? (e.g. log odds ratios -> odds ratios)
  if (exponentiate) {
    dat$b <- exp(estimate)
    dat$up_ci <- exp(dat$up_ci)
    dat$lo_ci <- exp(dat$lo_ci)
  }

  # Organise cats
  dat$subcategory <- as.factor(dat$subcategory)

  # Generate a simple trait column. Labels of the form "trait || rest" (the
  # form produced for outcomes taken from MR-Base) are cut down to the part
  # before the first "||"; labels without "||" (e.g. outcomes supplied by the
  # user) are kept as they are. Each label is processed on its own, so a label
  # with no text after "||", or with several "||", cannot shift the others.
  if (!ao_slc) {
    dat$trait <- as.character(dat[, TraitM])
    has_separator <- grepl("||", dat$trait, fixed = TRUE)
    if (any(has_separator)) {
      dat$trait[has_separator] <- trim(sub("\\|\\|.*$", "", dat$trait[has_separator]))
    }
  }

  dat1 <- data.frame(
    exposure = as.character(dat$exposure),
    outcome = as.character(dat$trait),
    outcome2 = as.character(dat$outcome2),
    category = as.character(dat$subcategory),
    effect = dat$b,
    se = dat$se,
    up_ci = dat$up_ci,
    lo_ci = dat$lo_ci,
    index = dat$index,
    weight = dat$weight,
    stringsAsFactors = FALSE
  )

  if (!is.null(addcols)) {
    # drop = FALSE keeps the column name when a single column is requested
    dat <- cbind(dat1, dat[, addcols, drop = FALSE])
  } else {
    dat <- dat1
  }

  # Restore the input order, then sort on the prefixed label (stable, so the
  # input order breaks any ties).
  dat <- dat[order(dat$index), ]
  dat <- dat[order(dat$outcome), ]

  dat
}

#' Sort results for 1-to-many forest plot
#'
#' This function sorts user-supplied results for the [forest_plot_1_to_many()] function. The user supplies their results in the form of a data frame.
#'
#' @param mr_res Data frame of results supplied by the user.
#' @param b Name of the column specifying the effect of the exposure on the outcome. The default is `"b"`.
#' @param trait_m The column specifying the names of the traits. Corresponds to 'many' in the 1-to-many forest plot. The default is `"outcome"`.
#' @param group Name of grouping variable in `mr_res`. Required for `sort_action` 1, 2 and 3.
#' @param priority If `sort_action = 3`, choose which value of the `trait_m` variable should be given priority and go above the other `trait_m` values.
#' The trait with the largest effect size for the prioritised group will go to the top of the plot.
#' @param sort_action Choose how to sort results.
#' \itemize{
#' \item `sort_action = 1`: sort results by effect size within groups. Use the group order supplied by the user.
#' \item `sort_action = 2`: sort results by effect size and group. Overrides the group ordering supplied by the user.
#' \item `sort_action = 3`: group results for the same trait together (e.g. multiple results for the same trait from different MR methods).
#' \item `sort_action = 4`: sort by decreasing effect size (largest effect size at top and smallest at bottom).
#' \item `sort_action = 5`: sort by increasing effect size (smallest effect size at top and largest at bottom).
#' }
#'
#' @export
#' @return data frame. For `sort_action = 3` the result carries an extra `b.sort` column holding the value each row was sorted on.
#'
sort_1_to_many <- function(
  mr_res,
  b = "b",
  trait_m = "outcome",
  sort_action = 4,
  group = NULL,
  priority = NULL
) {
  mr_res[, trait_m] <- as.character(mr_res[, trait_m])
  # `group` is optional (sort_action 4 and 5 do not use it)
  if (!is.null(group)) {
    mr_res[, group] <- as.character(mr_res[, group])
  }
  if (!b %in% names(mr_res)) {
    warning(
      "Column with effect estimates not found. Did you forget to specify the column of data containing your effect estimates?"
    )
  }

  if (sort_action == 1) {
    if (is.null(group)) {
      warning("You must indicate a grouping variable")
    }
    # Rank each row's group by first appearance, so the user's group order is
    # kept even when the rows of a group are not next to each other.
    if (is.null(group)) {
      group_rank <- rep(1L, nrow(mr_res))
    } else {
      group_rank <- match(mr_res[, group], unique(mr_res[, group]))
    }
    by_effect <- order(mr_res[, b], decreasing = TRUE)
    mr_res <- mr_res[by_effect, ]
    # order() leaves ties in place, so the effect-size order survives within
    # each group.
    mr_res <- mr_res[order(group_rank[by_effect]), ]
  }

  if (sort_action == 2) {
    if (is.null(group)) {
      warning("You must indicate a grouping variable")
    }
    mr_res <- mr_res[order(mr_res[, b], decreasing = TRUE), ]
    if (!is.null(group)) {
      mr_res <- mr_res[order(mr_res[, group]), ]
    }
  }

  if (sort_action == 3) {
    if (is.null(group)) {
      # Nothing can be grouped without a grouping column
      warning("You must indicate a grouping variable")
      return(mr_res)
    }
    if (is.null(priority)) {
      warning(
        "You must indicate which value of the grouping variable ",
        group,
        " to use as the priority value"
      )
    }

    mr_res$b.sort <- NA
    # mr_res1: groups with several results; mr_res2: groups with a single result
    group_values <- mr_res[, group]
    repeated <- group_values %in% group_values[duplicated(group_values)]
    mr_res1 <- mr_res[repeated, ]
    mr_res2 <- mr_res[!repeated, ]

    # Within a repeated group every row is sorted on the effect of the
    # priority row, so that the group moves as one block.
    is_priority <- mr_res1[, trait_m] %in% priority
    mr_res1$b.sort[is_priority] <- mr_res1[, b][is_priority]
    for (g in unique(mr_res1[, group])) {
      in_group <- mr_res1[, group] %in% g
      anchor <- mr_res1$b.sort[in_group & !is.na(mr_res1$b.sort)]
      # Groups without a priority row keep NA and are therefore sorted last
      if (length(anchor) > 0) {
        mr_res1$b.sort[in_group & is.na(mr_res1$b.sort)] <- anchor[1]
      }
    }
    # Use the column named by `b`, not a column literally called "b"
    mr_res2$b.sort <- mr_res2[, b]
    mr_res <- rbind(mr_res1, mr_res2)

    mr_res <- mr_res[order(mr_res$b.sort, decreasing = TRUE), ]
    # Within each group put the priority result first
    pieces <- lapply(unique(mr_res[, group]), function(g) {
      block <- mr_res[mr_res[, group] %in% g, ]
      first <- block[, trait_m] %in% priority
      rbind(block[first, ], block[!first, ])
    })
    if (length(pieces) > 0) {
      mr_res <- do.call(rbind, pieces)
    }
  }

  if (sort_action == 4) {
    mr_res <- mr_res[order(mr_res[, b], decreasing = TRUE), ]
  }

  if (sort_action == 5) {
    mr_res <- mr_res[order(mr_res[, b], decreasing = FALSE), ]
  }

  mr_res
}

#' A basic forest plot
#'
#' This function is used to create a basic forest plot.
#' It requires the output from [format_1_to_many()].
#'
#' @param dat Output from [format_1_to_many()]
#' @param section Which category in dat to plot. If `NULL` then prints everything.
#' @param colour_group Which exposure to plot. If `NULL` then prints everything grouping by colour.
#' @param colour_group_first The default is `TRUE`.
#' @param xlab x-axis label. Default=`NULL`.
#' @param bottom Show x-axis? Default=`TRUE`.
#' @param trans x-axis scale.
#' @param xlim x-axis limits. If supplied, confidence intervals are truncated to these limits.
#' @param lo Lower limit of x axis. If `lo` or `up` is `NULL` (the default) both are taken from the range of the confidence intervals.
#' @param up Upper limit of x axis. See `lo`.
#' @param subheading_size text size for the subheadings. The subheadings correspond to the values of the section argument. The default is `6`.
#' @param colour_scheme the general colour scheme for the plot. Default is to make all text and data points `"black"`.
#' @param shape_points the shape of the data points to pass to [ggplot2::geom_point()]. Default is set to `15` (filled square).
#'
#' @return ggplot object
forest_plot_basic2 <- function(
  dat,
  section = NULL,
  colour_group = NULL,
  colour_group_first = TRUE,
  xlab = NULL,
  bottom = TRUE,
  trans = "identity",
  xlim = NULL,
  lo = NULL,
  up = NULL,
  subheading_size = 6,
  colour_scheme = "black",
  shape_points = 15
) {
  # Only the bottom panel of a stack shows its x axis
  if (bottom) {
    text_colour <- ggplot2::element_text(colour = "black")
    tick_colour <- ggplot2::element_line(colour = "black")
    xlabname <- xlab
  } else {
    text_colour <- ggplot2::element_blank()
    tick_colour <- ggplot2::element_blank()
    xlabname <- NULL
  }

  # OR or log(OR)?
  # If CI are symmetric then log(OR)
  # Use this to guess where to put the null line
  symmetric_ci <- isTRUE(all.equal(dat$effect - dat$lo_ci, dat$up_ci - dat$effect))
  null_line <- if (symmetric_ci) 0 else 1

  # Truncate the confidence intervals to the requested limits
  if (!is.null(xlim)) {
    stopifnot(length(xlim) == 2)
    stopifnot(xlim[1] < xlim[2])
    dat$lo_ci <- pmax(dat$lo_ci, xlim[1], na.rm = TRUE)
    dat$up_ci <- pmin(dat$up_ci, xlim[2], na.rm = TRUE)
  }

  if (is.null(up) || is.null(lo)) {
    up <- max(dat$up_ci, na.rm = TRUE)
    lo <- min(dat$lo_ci, na.rm = TRUE)
  }
  # Widen the axis to the left by half its range to leave room for labels
  r <- up - lo
  lo_orig <- lo
  lo <- lo - r * 0.5

  # Plain indexing (rather than subset()) so that a column of dat that happens
  # to share its name with an argument cannot hijack the comparison.
  if (!is.null(section)) {
    dat <- dat[which(dat$category == section), , drop = FALSE]
  }

  if (!is.null(colour_group)) {
    dat <- dat[which(dat$exposure == colour_group), , drop = FALSE]
    point_plot <- ggplot2::geom_point(
      size = dat$weight,
      colour = colour_scheme,
      fill = colour_scheme,
      shape = shape_points
    )
  } else {
    point_plot <- ggplot2::geom_point(
      ggplot2::aes(colour = colour_scheme),
      size = dat$weight,
      fill = colour_scheme
    )
  }

  if (is.null(colour_group) || colour_group_first) {
    outcome_labels <- ggplot2::geom_text(
      ggplot2::aes(label = outcome2, colour = colour_scheme),
      x = lo,
      y = mean(c(1, length(unique(dat$exposure)))),
      hjust = 0,
      vjust = 0.5,
      size = 2.5
    )
    title_colour <- "black"
  } else {
    # No labels inside the panel: drop the extra room on the left and hide
    # the title by drawing it in white.
    outcome_labels <- NULL
    lo <- lo_orig
    title_colour <- "white"
  }

  main_title <- section

  # Alternate the background shading between consecutive results
  dat$lab <- dat$outcome
  l <- data.frame(lab = sort(unique(dat$lab)), col = "a", stringsAsFactors = FALSE)
  l$col[seq_len(nrow(l)) %% 2 == 0] <- "b"

  dat <- merge(dat, l, by = "lab", all.x = TRUE)
  dat <- dat[rev(seq_len(nrow(dat))), ]

  # The two ggplot2 code paths differ only in how the horizontal error bars
  # are drawn: geom_errorbarh() up to 3.5.2, geom_errorbar(orientation = "y")
  # afterwards.
  if (utils::packageVersion("ggplot2") <= "3.5.2") {
    ci_layer <- ggplot2::geom_errorbarh(
      ggplot2::aes(xmin = lo_ci, xmax = up_ci),
      height = 0,
      linewidth = 0.4,
      colour = colour_scheme
    )
  } else {
    ci_layer <- ggplot2::geom_errorbar(
      ggplot2::aes(xmin = lo_ci, xmax = up_ci),
      width = 0,
      linewidth = 0.4,
      colour = colour_scheme,
      orientation = "y"
    )
  }

  ggplot2::ggplot(dat, ggplot2::aes(x = effect, y = exposure)) +
    ggplot2::geom_rect(
      ggplot2::aes(fill = col),
      colour = colour_scheme,
      xmin = -Inf,
      xmax = Inf,
      ymin = -Inf,
      ymax = Inf
    ) +
    ggplot2::geom_vline(
      xintercept = seq(ceiling(lo_orig), ceiling(up), by = 0.5),
      alpha = 0,
      linewidth = 0.3
    ) +
    ggplot2::geom_vline(xintercept = null_line, colour = "#333333", linewidth = 0.3) +
    ci_layer +
    ggplot2::geom_point(
      colour = colour_scheme,
      size = 2.2,
      shape = shape_points,
      fill = colour_scheme
    ) +
    point_plot +
    ggplot2::facet_grid(lab ~ .) +
    ggplot2::scale_x_continuous(trans = trans, limits = c(lo, up)) +
    ggplot2::scale_colour_brewer(type = "qual") +
    ggplot2::scale_fill_manual(values = c("#eeeeee", "#ffffff"), guide = "none") +
    ggplot2::theme(
      axis.line = ggplot2::element_blank(),
      axis.text.y = ggplot2::element_blank(),
      axis.ticks.y = ggplot2::element_blank(),
      axis.text.x = text_colour,
      axis.ticks.x = tick_colour,
      strip.background = ggplot2::element_rect(fill = "white", colour = "white"),
      strip.text = ggplot2::element_text(family = "Courier New", face = "bold", size = 9),
      legend.position = "none",
      legend.direction = "vertical",
      panel.grid.minor.x = ggplot2::element_blank(),
      panel.grid.minor.y = ggplot2::element_blank(),
      panel.grid.major.y = ggplot2::element_blank(),
      plot.title = ggplot2::element_text(
        hjust = 0,
        size = subheading_size,
        colour = title_colour
      ),
      plot.margin = ggplot2::unit(c(2, 3, 2, 0), units = "points"),
      plot.background = ggplot2::element_rect(fill = "white"),
      panel.spacing = ggplot2::unit(0, "lines"),
      panel.background = ggplot2::element_rect(
        colour = "white",
        fill = colour_scheme,
        linewidth = 1
      ),
      strip.text.y = ggplot2::element_blank()
    ) +
    ggplot2::labs(y = NULL, x = xlabname, colour = NULL, fill = NULL, title = main_title) +
    outcome_labels
}


# Draw one text column of the 1-to-many forest plot.
#
# Builds a ggplot with the same row layout as forest_plot_basic2() (one facet
# per value of dat$outcome, alternating background shading) but with the values
# of one column of `dat` written as left-aligned text instead of points.
#
# dat             Output from format_1_to_many().
# section         Which category of dat to plot. NULL plots every row and
#                 hides the title by drawing it in white.
# var1            Name of the column of dat whose values are printed.
# bottom          If TRUE an x axis is drawn in white (invisible, but it takes
#                 up space); if FALSE no axis is drawn.
# title           Title used in place of the section name when section is "".
# subheading_size Text size of the title.
# colour_scheme   Colour of the text, the row borders and the panel.
# shape_points    Not used by this function.
# col_text_size   Size of the printed text.
#
# Returns a ggplot object.
forest_plot_names2 <- function(
  dat,
  section = NULL,
  var1 = "outcome2",
  bottom = TRUE,
  title = "",
  subheading_size = 6,
  colour_scheme = "black",
  shape_points = 15,
  col_text_size = 5
) {
  if (bottom) {
    text_colour <- ggplot2::element_text(colour = "white")
    tick_colour <- ggplot2::element_line(colour = "white")
    xlabname <- ""
  } else {
    text_colour <- ggplot2::element_blank()
    tick_colour <- ggplot2::element_blank()
    xlabname <- NULL
  }

  # The x axis carries no data here; the text is anchored at its left edge
  lo <- 0
  up <- 1

  # Plain indexing (rather than subset()) so that a column of dat that happens
  # to share its name with an argument cannot hijack the comparison.
  if (!is.null(section)) {
    dat <- dat[which(dat$category == section), , drop = FALSE]
    section_colour <- "black"
  } else {
    section_colour <- "white"
  }

  # An empty section name means "no stratification": show the column title
  # instead. isTRUE() keeps this safe when section is NULL.
  main_title <- if (isTRUE(section == "")) title else section

  # Look the column up by name through the .data pronoun, so that column names
  # which are not syntactic R names (e.g. containing spaces) work too.
  outcome_labels <- ggplot2::geom_text(
    ggplot2::aes(label = .data[[var1]]),
    x = lo,
    y = mean(c(1, length(unique(dat$exposure)))),
    hjust = 0,
    vjust = 0.5,
    size = col_text_size,
    color = colour_scheme
  )

  # Alternate the background shading between consecutive results
  dat$lab <- dat$outcome
  l <- data.frame(lab = sort(unique(dat$lab)), col = "a", stringsAsFactors = FALSE)
  l$col[seq_len(nrow(l)) %% 2 == 0] <- "b"

  dat <- merge(dat, l, by = "lab", all.x = TRUE)

  ggplot2::ggplot(dat, ggplot2::aes(x = effect, y = exposure)) +
    ggplot2::geom_rect(
      ggplot2::aes(fill = col),
      colour = colour_scheme,
      xmin = -Inf,
      xmax = Inf,
      ymin = -Inf,
      ymax = Inf
    ) +
    ggplot2::facet_grid(lab ~ .) +
    ggplot2::scale_x_continuous(limits = c(lo, up)) +
    ggplot2::scale_colour_brewer(type = "qual") +
    ggplot2::scale_fill_manual(values = c("#eeeeee", "#ffffff"), guide = "none") +
    ggplot2::theme(
      axis.line = ggplot2::element_blank(),
      axis.text.y = ggplot2::element_blank(),
      axis.ticks.y = ggplot2::element_blank(),
      axis.text.x = text_colour,
      axis.ticks.x = tick_colour,
      strip.background = ggplot2::element_rect(fill = "white", colour = "white"),
      strip.text = ggplot2::element_text(family = "Courier New", face = "bold", size = 11),
      legend.position = "none",
      legend.direction = "vertical",
      panel.grid.minor.x = ggplot2::element_blank(),
      panel.grid.minor.y = ggplot2::element_blank(),
      panel.grid.major.y = ggplot2::element_blank(),
      plot.title = ggplot2::element_text(
        hjust = 0,
        size = subheading_size,
        colour = section_colour
      ),
      plot.margin = ggplot2::unit(c(2, 0, 2, 0), units = "points"),
      plot.background = ggplot2::element_rect(fill = "white"),
      panel.spacing = ggplot2::unit(0, "lines"),
      panel.background = ggplot2::element_rect(
        colour = colour_scheme,
        fill = colour_scheme,
        linewidth = 1
      ),
      strip.text.y = ggplot2::element_blank()
    ) +
    ggplot2::labs(y = NULL, x = xlabname, colour = NULL, fill = NULL, title = main_title) +
    outcome_labels
}


# Draw one additional text column of the 1-to-many forest plot.
#
# Same layout as forest_plot_names2() (one facet per value of dat$outcome,
# alternating background shading), printing the values of the column `addcol`
# as left-aligned text under the heading `addcol_title`.
#
# dat             Output from format_1_to_many() that includes the column
#                 named by addcol.
# section         Which category of dat to plot. NULL plots every row and
#                 hides the title by drawing it in white.
# addcol          Name of the column of dat whose values are printed.
# bottom          If TRUE an x axis is drawn in white (invisible, but it takes
#                 up space); if FALSE no axis is drawn.
# addcol_title    Title shown above the column.
# subheading_size Text size of the title.
# colour_scheme   Colour of the text, the row borders and the panel fill.
# shape_points    Not used by this function.
# col_text_size   Size of the printed text.
#
# Returns a ggplot object.
forest_plot_addcol <- function(
  dat,
  section = NULL,
  addcol = NULL,
  bottom = TRUE,
  addcol_title = NULL,
  subheading_size = 6,
  colour_scheme = "black",
  shape_points = 15,
  col_text_size = 5
) {
  if (bottom) {
    text_colour <- ggplot2::element_text(colour = "white")
    tick_colour <- ggplot2::element_line(colour = "white")
    xlabname <- ""
  } else {
    text_colour <- ggplot2::element_blank()
    tick_colour <- ggplot2::element_blank()
    xlabname <- NULL
  }

  # The x axis carries no data here; the text is anchored at its left edge
  lo <- 0
  up <- 1

  # Plain indexing (rather than subset()) so that a column of dat that happens
  # to share its name with an argument cannot hijack the comparison.
  if (!is.null(section)) {
    dat <- dat[which(dat$category == section), , drop = FALSE]
    section_colour <- "black"
  } else {
    section_colour <- "white"
  }

  # Look the column up by name through the .data pronoun, so that column names
  # which are not syntactic R names (e.g. containing spaces) work too.
  outcome_labels <- ggplot2::geom_text(
    ggplot2::aes(label = .data[[addcol]]),
    x = lo,
    y = mean(c(1, length(unique(dat$exposure)))),
    hjust = 0,
    vjust = 0.5,
    size = col_text_size,
    colour = colour_scheme
  )

  # Alternate the background shading between consecutive results
  dat$lab <- dat$outcome
  l <- data.frame(lab = sort(unique(dat$lab)), col = "a", stringsAsFactors = FALSE)
  l$col[seq_len(nrow(l)) %% 2 == 0] <- "b"

  dat <- merge(dat, l, by = "lab", all.x = TRUE)

  ggplot2::ggplot(dat, ggplot2::aes(x = effect, y = exposure)) +
    ggplot2::geom_rect(
      ggplot2::aes(fill = col),
      colour = colour_scheme,
      xmin = -Inf,
      xmax = Inf,
      ymin = -Inf,
      ymax = Inf
    ) +
    ggplot2::facet_grid(lab ~ .) +
    ggplot2::scale_x_continuous(limits = c(lo, up)) +
    ggplot2::scale_colour_brewer(type = "qual") +
    ggplot2::scale_fill_manual(values = c("#eeeeee", "#ffffff"), guide = "none") +
    ggplot2::theme(
      axis.line = ggplot2::element_blank(),
      axis.text.y = ggplot2::element_blank(),
      axis.ticks.y = ggplot2::element_blank(),
      axis.text.x = text_colour,
      axis.ticks.x = tick_colour,
      strip.background = ggplot2::element_rect(fill = "white", colour = "white"),
      strip.text = ggplot2::element_text(family = "Courier New", face = "bold", size = 11),
      legend.position = "none",
      legend.direction = "vertical",
      panel.grid.minor.x = ggplot2::element_blank(),
      panel.grid.minor.y = ggplot2::element_blank(),
      panel.grid.major.y = ggplot2::element_blank(),
      plot.title = ggplot2::element_text(
        hjust = 0,
        size = subheading_size,
        colour = section_colour
      ),
      plot.margin = ggplot2::unit(c(2, 0, 2, 0), units = "points"),
      plot.background = ggplot2::element_rect(fill = "white"),
      panel.spacing = ggplot2::unit(0, "lines"),
      panel.background = ggplot2::element_rect(colour = "red", fill = colour_scheme, linewidth = 1),
      strip.text.y = ggplot2::element_blank(),
      strip.text.x = ggplot2::element_blank()
    ) +
    ggplot2::labs(y = NULL, x = xlabname, colour = NULL, fill = NULL, title = addcol_title) +
    outcome_labels
}

#' 1-to-many forest plot
#'
#' Plot results from an analysis of multiple exposures against a single outcome or a single exposure against multiple outcomes.
#' Plots effect estimates and 95 percent confidence intervals.
#' The ordering of results in the plot is determined by the order supplied by the user.
#' Users may find [sort_1_to_many()] helpful for sorting their results prior to using the 1-to-many forest plot. The plot function works best for 50 results and is not designed to handle more than 100 results.
#'
#' @param mr_res Data frame of results supplied by the user. The default is `"mr_res"`, which is only a placeholder: a data frame must be supplied.
#' @param b Name of the column specifying the effect of the exposure on the outcome. The default is `"b"`.
#' @param se Name of the column specifying the standard error for b. The default is `"se"`.
#' @param TraitM The column specifying the names of the traits. Corresponds to 'many' in the 1-to-many forest plot. The default is `"outcome"`.
#' @param col1_title Title for the column specified by the TraitM argument. The default is `""`.
#' @param col1_width Width of Y axis label for the column specified by the TraitM argument. The default is `1`.
#' @param addcols Name of additional columns to plot. Character vector. The default is `NULL`.
#' @param addcol_titles Titles of additional columns specified by the addcols argument. Character vector. The default is `""`.
#' @param addcol_widths Widths of Y axis labels for additional columns specified by the addcols argument. Numeric vector with one value per entry of `addcols`. The default is `NULL`.
#' @param xlab X-axis label, default is `"Effect (95% confidence interval)"`.
#' @param by Name of the grouping variable to stratify results on. Default is `NULL`.
#' @param subheading_size text size for the subheadings specified in by argument. The default is `6`.
#' @param exponentiate Convert log odds ratios to odds ratios? Default is `FALSE`.
#' @param ao_slc Logical; retrieve trait subcategory information using available_outcomes(). Default is `TRUE`.
#' @param trans Specify x-axis scale. e.g. "identity", "log2", etc. If set to "identity" an additive scale is used. If set to log2 the x-axis is plotted on a multiplicative / doubling scale (preferable when plotting odds ratios). Default is `"identity"`.
#' @param lo Lower limit of X axis to plot.
#' @param up upper limit of X axis to plot.
#' @param colour_scheme the general colour scheme for the plot. Default is to make all text and data points `"black"`.
#' @param shape_points the shape of the data points to pass to [ggplot2::geom_point()]. Default is set to `15` (filled square).
#' @param col_text_size The default is `5`.
#' @param weight Passed on to [format_1_to_many()]. The default is `NULL`.
#'
#' @export
#' @return grid plot object
forest_plot_1_to_many <- function(
  mr_res = "mr_res",
  b = "b",
  se = "se",
  TraitM = "outcome",
  col1_width = 1,
  col1_title = "",
  exponentiate = FALSE,
  trans = "identity",
  ao_slc = TRUE,
  lo = NULL,
  up = NULL,
  by = NULL,
  xlab = "Effect (95% confidence interval)",
  addcols = NULL,
  addcol_widths = NULL,
  addcol_titles = "",
  subheading_size = 6,
  shape_points = 15,
  colour_scheme = "black",
  col_text_size = 5,
  weight = NULL
) {
  if (!is.data.frame(mr_res)) {
    stop("mr_res must be a data frame of results", call. = FALSE)
  }

  xlim <- NULL
  # One text column for the trait names plus one per additional column
  ncols <- 1 + length(addcols)
  if (all(addcol_titles == "")) {
    addcol_titles <- rep(addcol_titles, length(addcols))
  }

  dat <- format_1_to_many(
    mr_res = mr_res,
    b = b,
    se = se,
    exponentiate = exponentiate,
    ao_slc = ao_slc,
    by = by,
    TraitM = TraitM,
    addcols = addcols,
    weight = weight
  )

  if (length(addcols) != length(addcol_widths)) {
    warning("length of addcols not equal to length of addcol_widths")
  }

  sec <- unique(as.character(dat$category))
  columns <- unique(dat$exposure)
  is_last <- seq_along(sec) == length(sec)

  # The grid is filled row by row: for every section first the trait names,
  # then the additional columns, then one forest plot per exposure.
  l <- list()
  h <- rep(0, length(sec))
  count <- 1
  for (i in seq_along(sec)) {
    # Row height follows the number of results in the section
    h[i] <- length(unique(dat$outcome[which(dat$category == sec[i])]))

    l[[count]] <- forest_plot_names2(
      dat,
      sec[i],
      bottom = is_last[i],
      title = col1_title,
      subheading_size = subheading_size,
      colour_scheme = colour_scheme,
      shape_points = shape_points,
      col_text_size = col_text_size
    )
    count <- count + 1

    for (j in seq_along(addcols)) {
      l[[count]] <- forest_plot_addcol(
        dat,
        sec[i],
        addcol = addcols[j],
        addcol_title = addcol_titles[j],
        bottom = is_last[i],
        subheading_size = subheading_size,
        colour_scheme = colour_scheme,
        shape_points = shape_points,
        col_text_size = col_text_size
      )
      count <- count + 1
    }

    for (j in seq_along(columns)) {
      l[[count]] <- forest_plot_basic2(
        dat,
        sec[i],
        bottom = is_last[i],
        colour_group = columns[j],
        colour_group_first = FALSE,
        xlab = paste0(xlab, " ", columns[j]),
        lo = lo,
        up = up,
        trans = trans,
        xlim = xlim,
        subheading_size = subheading_size,
        colour_scheme = colour_scheme,
        shape_points = shape_points
      )
      count <- count + 1
    }
  }

  # Allow for the title of every section and for the x axis of the last one
  h <- h + 5
  h[length(sec)] <- h[length(sec)] + 1

  cowplot::plot_grid(
    gridExtra::arrangeGrob(
      grobs = l,
      ncol = length(columns) + ncols,
      nrow = length(h),
      heights = h,
      widths = c(col1_width, addcol_widths, rep(5, length(columns)))
    )
  )
}
